from torch2trt.trt_engine import TensorRTEngine, init


class Inference:
    """Thin wrapper that loads a TensorRT engine and runs inference on text.

    Both file-system locations used to be hard-coded. They are now optional
    constructor arguments that default to the original values, so
    ``Inference()`` behaves exactly as before.
    """

    # Serialized TensorRT engine loaded by default.
    DEFAULT_ENGINE_PATH = '/UseModel/prompt.trt'
    # Directory handed to TensorRTEngine as its second argument. Presumably
    # the HuggingFace tokenizer/model cache for chinese-roberta-wwm-ext;
    # TODO(review): confirm what TensorRTEngine reads from it.
    DEFAULT_MODEL_PATH = '/data0/jianyu10/PTM/huggingface_model_cache/chinese-roberta-wwm-ext'

    def __init__(self, engine_path=None, model_path=None):
        """Build the TensorRT engine wrapper.

        Args:
            engine_path: Path to the TensorRT engine file. Defaults to
                ``DEFAULT_ENGINE_PATH`` when ``None``.
            model_path: Path passed as the second argument to
                ``TensorRTEngine``. Defaults to ``DEFAULT_MODEL_PATH`` when
                ``None``.
        """
        if engine_path is None:
            engine_path = self.DEFAULT_ENGINE_PATH
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH
        self.trtengine = TensorRTEngine(engine_path, model_path)
        # NOTE(review): ``init`` comes from torch2trt.trt_engine and its effect
        # is not visible here. It is kept after engine construction to match the
        # original call order. Confirm whether it must instead run first.
        init()

    def inference(self, title):
        """Run the engine on ``title`` and return its result unchanged.

        The result is whatever ``TensorRTEngine.inference`` returns
        (presumably an embedding vector; verify against the engine).
        """
        return self.trtengine.inference(title)


if __name__ == '__main__':
    # Placeholder entry point: running this file directly currently does
    # nothing; no command-line behavior is implemented here.
    ...
